package com.nowcoder.topic.dp.hard;

import java.math.BigInteger;
import java.util.ArrayList;

/**
 * NC328 矩阵取数游戏
 * @author d3y1
 */
public class NC328{
    private final int MOD = 1000000007;
    private final int BASE = 2;

    /**
     * 代码中的类名、方法名、参数名已经指定，请勿修改，直接返回方法规定的值即可
     *
     * 相似 -> NC327 取数游戏
     * @param matrix int整型ArrayList<ArrayList<>>
     * @return int整型
     */
    public int matrixScore (ArrayList<ArrayList<Integer>> matrix) {
//        return solution1(matrix);
        return solution2(matrix);
    }

    /**
     * 动态规划 + 滑动窗口: 值大溢出
     *
     * dp[t][i][j]表示矩阵第t行剩余列区间[i,j]取数后的最大得分
     *
     * pow[i]表示第i次取数的权值
     * gap表示窗口间隔
     *
     * j = i+gap
     * gap=0时: i=j, 剩余列区间[i,i], 该行仅剩最后1个元素, 当为最后一次(第m次)取数, 权值为pow[m]
     * gap>0时: i<j, 剩余列区间[i,j], 该行仅剩最后(gap+1)个元素, 已经取走(m-(gap+1))个元素, 当前为第((m-(gap+1))+1)取数, 即第(m-gap)次取数, 权值为pow[m-gap]
     *
     * @param matrix
     * @return
     */
    private int solution1(ArrayList<ArrayList<Integer>> matrix){
        int n = matrix.size();
        int m = matrix.get(0).size();

        long[][][] dp = new long[n+1][m+1][m+1];

        long result = 0;
        for(int t=1; t<=n; t++){
            for(int gap=0; gap<m; gap++){
                for(int i=1,j=i+gap; j<=m; i++,j++){
                    if(i == j){
                        dp[t][i][j] = ((long)matrix.get(t-1).get(i-1)*(1L<<(m-gap)));
                    }else{
                        dp[t][i][j] = (Math.max((long)matrix.get(t-1).get(i-1)*(1L<<(m-gap))+dp[t][i+1][j], (long)matrix.get(t-1).get(j-1)*(1L<<(m-gap))+dp[t][i][j-1]));
                    }
                }
            }
            result = (result+(dp[t][1][m] % MOD)) % MOD;
        }

        return (int)result;
    }

    /**
     * 动态规划 + 滑动窗口 + BigInteger
     *
     * dp[t][i][j]表示矩阵第t行剩余列区间[i,j]取数后的最大得分
     *
     * pow[i]表示第i次取数的权值
     * gap表示窗口间隔
     *
     * j = i+gap
     * gap=0时: i=j, 剩余列区间[i,i], 该行仅剩最后1个元素, 当为最后一次(第m次)取数, 权值为pow[m]
     * gap>0时: i<j, 剩余列区间[i,j], 该行仅剩最后(gap+1)个元素, 已经取走(m-(gap+1))个元素, 当前为第((m-(gap+1))+1)取数, 即第(m-gap)次取数, 权值为pow[m-gap]
     *
     * @param matrix
     * @return
     */
    private int solution2(ArrayList<ArrayList<Integer>> matrix){
        int n = matrix.size();
        int m = matrix.get(0).size();

        BigInteger[][][] dp = new BigInteger[n+1][m+1][m+1];

        // init pow
        BigInteger[] pow = new BigInteger[m+1];
        pow[0] = new BigInteger("1");
        for(int i=1; i<=m; i++){
            pow[i] = pow[i-1].multiply(BigInteger.valueOf(BASE));
        }

        int result = 0;
        BigInteger left,right;
        // 各行之间互相独立
        for(int t=1; t<=n; t++){
            // 滑动窗口
            for(int gap=0; gap<m; gap++){
                for(int i=1,j=i+gap; j<=m; i++,j++){
                    if(i == j){
                        dp[t][i][j] = (pow[m-gap].multiply(BigInteger.valueOf(matrix.get(t-1).get(i-1)))) ;
                    }else{
                        left = dp[t][i+1][j].add(pow[m-gap].multiply(BigInteger.valueOf(matrix.get(t-1).get(i-1))));
                        right = dp[t][i][j-1].add(pow[m-gap].multiply(BigInteger.valueOf(matrix.get(t-1).get(j-1))));

                        if(left.compareTo(right) > 0){
                            dp[t][i][j] = left;
                        }else{
                            dp[t][i][j] = right;
                        }
                    }
                }
            }
            result = (result+(dp[t][1][m].mod(BigInteger.valueOf(MOD)).intValue())) % MOD;
        }

        return result;
    }
}